
<!DOCTYPE html>

<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>tigramite.independence_tests.cmiknn &#8212; Tigramite 5.2 documentation</title>
    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
    <link rel="stylesheet" type="text/css" href="../../../_static/alabaster.css" />
    <script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
    <script src="../../../_static/jquery.js"></script>
    <script src="../../../_static/underscore.js"></script>
    <script src="../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    <script src="../../../_static/doctools.js"></script>
    <link rel="index" title="Index" href="../../../genindex.html" />
    <link rel="search" title="Search" href="../../../search.html" />
   
  <link rel="stylesheet" href="../../../_static/custom.css" type="text/css" />
  
  
  <meta name="viewport" content="width=device-width, initial-scale=0.9, maximum-scale=0.9" />

  </head><body>
  

    <div class="document">
      <div class="documentwrapper">
        <div class="bodywrapper">
          

          <div class="body" role="main">
            
  <h1>Source code for tigramite.independence_tests.cmiknn</h1><div class="highlight"><pre>
<span></span><span class="sd">&quot;&quot;&quot;Tigramite causal discovery for time series.&quot;&quot;&quot;</span>

<span class="c1"># Author: Jakob Runge &lt;jakob@jakob-runge.com&gt;</span>
<span class="c1">#</span>
<span class="c1"># License: GNU General Public License v3.0</span>

<span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span>
<span class="kn">from</span> <span class="nn">scipy</span> <span class="kn">import</span> <span class="n">special</span><span class="p">,</span> <span class="n">spatial</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">.independence_tests_base</span> <span class="kn">import</span> <span class="n">CondIndTest</span>
<span class="kn">from</span> <span class="nn">numba</span> <span class="kn">import</span> <span class="n">jit</span>
<span class="kn">import</span> <span class="nn">warnings</span>


<div class="viewcode-block" id="CMIknn"><a class="viewcode-back" href="../../../index.html#tigramite.independence_tests.cmiknn.CMIknn">[docs]</a><span class="k">class</span> <span class="nc">CMIknn</span><span class="p">(</span><span class="n">CondIndTest</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Conditional mutual information test based on nearest-neighbor estimator.</span>

<span class="sd">    Conditional mutual information is the most general dependency measure coming</span>
<span class="sd">    from an information-theoretic framework. It makes no assumptions about the</span>
<span class="sd">    parametric form of the dependencies by directly estimating the underlying</span>
<span class="sd">    joint density. The test here is based on the estimator in  S. Frenzel and B.</span>
<span class="sd">    Pompe, Phys. Rev. Lett. 99, 204101 (2007), combined with a shuffle test to</span>
<span class="sd">    generate  the distribution under the null hypothesis of independence first</span>
<span class="sd">    used in [3]_. The knn-estimator is suitable only for variables taking a</span>
<span class="sd">    continuous range of values. For discrete variables use the CMIsymb class.</span>

<span class="sd">    Notes</span>
<span class="sd">    -----</span>
<span class="sd">    CMI is given by</span>

<span class="sd">    .. math:: I(X;Y|Z) &amp;= \int p(z)  \iint  p(x,y|z) \log</span>
<span class="sd">                \frac{ p(x,y |z)}{p(x|z)\cdot p(y |z)} \,dx dy dz</span>

<span class="sd">    Its knn-estimator is given by</span>

<span class="sd">    .. math:: \widehat{I}(X;Y|Z)  &amp;=   \psi (k) + \frac{1}{T} \sum_{t=1}^T</span>
<span class="sd">            \left[ \psi(k_{Z,t}) - \psi(k_{XZ,t}) - \psi(k_{YZ,t}) \right]</span>

<span class="sd">    where :math:`\psi` is the Digamma function.  This estimator has as a</span>
<span class="sd">    parameter the number of nearest-neighbors :math:`k` which determines the</span>
<span class="sd">    size of hyper-cubes around each (high-dimensional) sample point. Then</span>
<span class="sd">    :math:`k_{Z,},k_{XZ},k_{YZ}` are the numbers of neighbors in the respective</span>
<span class="sd">    subspaces.</span>

<span class="sd">    :math:`k` can be viewed as a density smoothing parameter (although it is</span>
<span class="sd">    data-adaptive unlike fixed-bandwidth estimators). For large :math:`k`, the</span>
<span class="sd">    underlying dependencies are more smoothed and CMI has a larger bias,</span>
<span class="sd">    but lower variance, which is more important for significance testing. Note</span>
<span class="sd">    that the estimated CMI values can be slightly negative while CMI is a non-</span>
<span class="sd">    negative quantity.</span>

<span class="sd">    This method requires the scipy.spatial.cKDTree package.</span>

<span class="sd">    References</span>
<span class="sd">    ----------</span>

<span class="sd">    .. [3] J. Runge (2018): Conditional Independence Testing Based on a</span>
<span class="sd">           Nearest-Neighbor Estimator of Conditional Mutual Information.</span>
<span class="sd">           In Proceedings of the 21st International Conference on Artificial</span>
<span class="sd">           Intelligence and Statistics.</span>
<span class="sd">           http://proceedings.mlr.press/v84/runge18a.html</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    knn : int or float, optional (default: 0.2)</span>
<span class="sd">        Number of nearest-neighbors which determines the size of hyper-cubes</span>
<span class="sd">        around each (high-dimensional) sample point. If smaller than 1, this is</span>
<span class="sd">        computed as a fraction of T, hence knn=knn*T. For knn larger or equal to</span>
<span class="sd">        1, this is the absolute number.</span>

<span class="sd">    shuffle_neighbors : int, optional (default: 5)</span>
<span class="sd">        Number of nearest-neighbors within Z for the shuffle surrogates which</span>
<span class="sd">        determines the size of hyper-cubes around each (high-dimensional) sample</span>
<span class="sd">        point.</span>

<span class="sd">    transform : {&#39;ranks&#39;, &#39;standardize&#39;,  &#39;uniform&#39;, False}, optional</span>
<span class="sd">        (default: &#39;ranks&#39;)</span>
<span class="sd">        Whether to transform the array beforehand by standardizing</span>
<span class="sd">        or transforming to uniform marginals.</span>

<span class="sd">    workers : int (optional, default = -1)</span>
<span class="sd">        Number of workers to use for parallel processing. If -1 is given</span>
<span class="sd">        all processors are used. Default: -1.</span>

<span class="sd">    model_selection_folds : int (optional, default = 3)</span>
<span class="sd">        Number of folds in cross-validation used in model selection.</span>

<span class="sd">    significance : str, optional (default: &#39;shuffle_test&#39;)</span>
<span class="sd">        Type of significance test to use. For CMIknn only &#39;fixed_thres&#39; and</span>
<span class="sd">        &#39;shuffle_test&#39; are available.</span>

<span class="sd">    **kwargs :</span>
<span class="sd">        Arguments passed on to parent class CondIndTest.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">measure</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Concrete property to return the measure of the independence test</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_measure</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">knn</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span>
                 <span class="n">shuffle_neighbors</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
                 <span class="n">significance</span><span class="o">=</span><span class="s1">&#39;shuffle_test&#39;</span><span class="p">,</span>
                 <span class="n">transform</span><span class="o">=</span><span class="s1">&#39;ranks&#39;</span><span class="p">,</span>
                 <span class="n">workers</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
                 <span class="n">model_selection_folds</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
                 <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="c1"># Set the member variables</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">knn</span> <span class="o">=</span> <span class="n">knn</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">shuffle_neighbors</span> <span class="o">=</span> <span class="n">shuffle_neighbors</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">=</span> <span class="n">transform</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_measure</span> <span class="o">=</span> <span class="s1">&#39;cmi_knn&#39;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">two_sided</span> <span class="o">=</span> <span class="kc">False</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">residual_based</span> <span class="o">=</span> <span class="kc">False</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">recycle_residuals</span> <span class="o">=</span> <span class="kc">False</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">workers</span> <span class="o">=</span> <span class="n">workers</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model_selection_folds</span> <span class="o">=</span> <span class="n">model_selection_folds</span>
        <span class="c1"># Call the parent constructor</span>
        <span class="n">CondIndTest</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">significance</span><span class="o">=</span><span class="n">significance</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="c1"># Print some information about construction</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">knn</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;knn/T = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;knn = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="p">)</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;shuffle_neighbors = </span><span class="si">%d</span><span class="se">\n</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">shuffle_neighbors</span><span class="p">)</span>

    <span class="nd">@jit</span><span class="p">(</span><span class="n">forceobj</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">_get_nearest_neighbors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">array</span><span class="p">,</span> <span class="n">xyz</span><span class="p">,</span> <span class="n">knn</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns nearest neighbors according to Frenzel and Pompe (2007).</span>

<span class="sd">        Retrieves the distances eps to the k-th nearest neighbors for every</span>
<span class="sd">        sample in joint space XYZ and returns the numbers of nearest neighbors</span>
<span class="sd">        within eps in subspaces Z, XZ, YZ.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        array : array-like</span>
<span class="sd">            data array with X, Y, Z in rows and observations in columns</span>

<span class="sd">        xyz : array of ints</span>
<span class="sd">            XYZ identifier array of shape (dim,).</span>

<span class="sd">        knn : int or float</span>
<span class="sd">            Number of nearest-neighbors which determines the size of hyper-cubes</span>
<span class="sd">            around each (high-dimensional) sample point. If smaller than 1, this</span>
<span class="sd">            is computed as a fraction of T, hence knn=knn*T. For knn larger or</span>
<span class="sd">            equal to 1, this is the absolute number.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        k_xz, k_yz, k_z : tuple of arrays of shape (T,)</span>
<span class="sd">            Nearest neighbors in subspaces.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
        <span class="n">xyz</span> <span class="o">=</span> <span class="n">xyz</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>

        <span class="n">dim</span><span class="p">,</span> <span class="n">T</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">shape</span>

        <span class="c1"># Add noise to destroy ties...</span>
        <span class="n">array</span> <span class="o">+=</span> <span class="p">(</span><span class="mf">1E-6</span> <span class="o">*</span> <span class="n">array</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
                  <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">random_state</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="n">array</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">array</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])))</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">==</span> <span class="s1">&#39;standardize&#39;</span><span class="p">:</span>
            <span class="c1"># Standardize</span>
            <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
            <span class="n">array</span> <span class="o">-=</span> <span class="n">array</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
            <span class="n">std</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">dim</span><span class="p">):</span>
                <span class="k">if</span> <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">!=</span> <span class="mf">0.</span><span class="p">:</span>
                    <span class="n">array</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/=</span> <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
            <span class="c1"># array /= array.std(axis=1).reshape(dim, 1)</span>
            <span class="c1"># FIXME: If the time series is constant, return nan rather than</span>
            <span class="c1"># raising Exception</span>
            <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">std</span> <span class="o">==</span> <span class="mf">0.</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;Possibly constant array!&quot;</span><span class="p">)</span>
                <span class="c1"># raise ValueError(&quot;nans after standardizing, &quot;</span>
                <span class="c1">#                  &quot;possibly constant array!&quot;)</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">==</span> <span class="s1">&#39;uniform&#39;</span><span class="p">:</span>
            <span class="n">array</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_trafo2uniform</span><span class="p">(</span><span class="n">array</span><span class="p">)</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">==</span> <span class="s1">&#39;ranks&#39;</span><span class="p">:</span>
            <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>

        <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">T</span>
        <span class="n">tree_xyz</span> <span class="o">=</span> <span class="n">spatial</span><span class="o">.</span><span class="n">cKDTree</span><span class="p">(</span><span class="n">array</span><span class="p">)</span>
        <span class="n">epsarray</span> <span class="o">=</span> <span class="n">tree_xyz</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">array</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="p">[</span><span class="n">knn</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">p</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
                                  <span class="n">eps</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>

        <span class="c1"># To search neighbors &lt; eps</span>
        <span class="n">epsarray</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">epsarray</span><span class="p">,</span> <span class="mf">0.99999</span><span class="p">)</span>

        <span class="c1"># Subsample indices</span>
        <span class="n">x_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">y_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">z_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>

        <span class="c1"># Find nearest neighbors in subspaces</span>
        <span class="n">xz</span> <span class="o">=</span> <span class="n">array</span><span class="p">[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">x_indices</span><span class="p">,</span> <span class="n">z_indices</span><span class="p">))]</span>
        <span class="n">tree_xz</span> <span class="o">=</span> <span class="n">spatial</span><span class="o">.</span><span class="n">cKDTree</span><span class="p">(</span><span class="n">xz</span><span class="p">)</span>
        <span class="n">k_xz</span> <span class="o">=</span> <span class="n">tree_xz</span><span class="o">.</span><span class="n">query_ball_point</span><span class="p">(</span><span class="n">xz</span><span class="p">,</span> <span class="n">r</span><span class="o">=</span><span class="n">epsarray</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">,</span> <span class="n">return_length</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

        <span class="n">yz</span> <span class="o">=</span> <span class="n">array</span><span class="p">[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">y_indices</span><span class="p">,</span> <span class="n">z_indices</span><span class="p">))]</span>
        <span class="n">tree_yz</span> <span class="o">=</span> <span class="n">spatial</span><span class="o">.</span><span class="n">cKDTree</span><span class="p">(</span><span class="n">yz</span><span class="p">)</span>
        <span class="n">k_yz</span> <span class="o">=</span> <span class="n">tree_yz</span><span class="o">.</span><span class="n">query_ball_point</span><span class="p">(</span><span class="n">yz</span><span class="p">,</span> <span class="n">r</span><span class="o">=</span><span class="n">epsarray</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">,</span> <span class="n">return_length</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">z_indices</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">z</span> <span class="o">=</span> <span class="n">array</span><span class="p">[:,</span> <span class="n">z_indices</span><span class="p">]</span>
            <span class="n">tree_z</span> <span class="o">=</span> <span class="n">spatial</span><span class="o">.</span><span class="n">cKDTree</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
            <span class="n">k_z</span> <span class="o">=</span> <span class="n">tree_z</span><span class="o">.</span><span class="n">query_ball_point</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">r</span><span class="o">=</span><span class="n">epsarray</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">,</span> <span class="n">return_length</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="c1"># Number of neighbors is T when z is empty.</span>
            <span class="n">k_z</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">k_xz</span><span class="p">,</span> <span class="n">k_yz</span><span class="p">,</span> <span class="n">k_z</span>

<div class="viewcode-block" id="CMIknn.get_dependence_measure"><a class="viewcode-back" href="../../../index.html#tigramite.independence_tests.cmiknn.CMIknn.get_dependence_measure">[docs]</a>    <span class="k">def</span> <span class="nf">get_dependence_measure</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">array</span><span class="p">,</span> <span class="n">xyz</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns CMI estimate as described in Frenzel and Pompe PRL (2007).</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        array : array-like</span>
<span class="sd">            data array with X, Y, Z in rows and observations in columns</span>

<span class="sd">        xyz : array of ints</span>
<span class="sd">            XYZ identifier array of shape (dim,).</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        val : float</span>
<span class="sd">            Conditional mutual information estimate.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">dim</span><span class="p">,</span> <span class="n">T</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">shape</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">knn</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">knn_here</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="o">*</span><span class="n">T</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">knn_here</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="p">))</span>


        <span class="n">k_xz</span><span class="p">,</span> <span class="n">k_yz</span><span class="p">,</span> <span class="n">k_z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_nearest_neighbors</span><span class="p">(</span><span class="n">array</span><span class="o">=</span><span class="n">array</span><span class="p">,</span>
                                                      <span class="n">xyz</span><span class="o">=</span><span class="n">xyz</span><span class="p">,</span>
                                                      <span class="n">knn</span><span class="o">=</span><span class="n">knn_here</span><span class="p">)</span>

        <span class="n">val</span> <span class="o">=</span> <span class="n">special</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">knn_here</span><span class="p">)</span> <span class="o">-</span> <span class="p">(</span><span class="n">special</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">k_xz</span><span class="p">)</span> <span class="o">+</span>
                                           <span class="n">special</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">k_yz</span><span class="p">)</span> <span class="o">-</span>
                                           <span class="n">special</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">k_z</span><span class="p">))</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>

        <span class="k">return</span> <span class="n">val</span></div>


<div class="viewcode-block" id="CMIknn.get_shuffle_significance"><a class="viewcode-back" href="../../../index.html#tigramite.independence_tests.cmiknn.CMIknn.get_shuffle_significance">[docs]</a>    <span class="k">def</span> <span class="nf">get_shuffle_significance</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">array</span><span class="p">,</span> <span class="n">xyz</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span>
                                 <span class="n">return_null_dist</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns p-value for nearest-neighbor shuffle significance test.</span>

<span class="sd">        For non-empty Z, overwrites get_shuffle_significance from the parent</span>
<span class="sd">        class  which is a block shuffle test, which does not preserve</span>
<span class="sd">        dependencies of X and Y with Z. Here the parameter shuffle_neighbors is</span>
<span class="sd">        used to permute only those values :math:`x_i` and :math:`x_j` for which</span>
<span class="sd">        :math:`z_j` is among the nearest niehgbors of :math:`z_i`. If Z is</span>
<span class="sd">        empty, the block-shuffle test is used.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        array : array-like</span>
<span class="sd">            data array with X, Y, Z in rows and observations in columns</span>

<span class="sd">        xyz : array of ints</span>
<span class="sd">            XYZ identifier array of shape (dim,).</span>

<span class="sd">        value : number</span>
<span class="sd">            Value of test statistic for unshuffled estimate.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        pval : float</span>
<span class="sd">            p-value</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">dim</span><span class="p">,</span> <span class="n">T</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">shape</span>

        <span class="c1"># Skip shuffle test if value is above threshold</span>
        <span class="c1"># if value &gt; self.minimum threshold:</span>
        <span class="c1">#     if return_null_dist:</span>
        <span class="c1">#         return 0., None</span>
        <span class="c1">#     else:</span>
        <span class="c1">#         return 0.</span>

        <span class="c1"># max_neighbors = max(1, int(max_neighbor_ratio*T))</span>
        <span class="n">x_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">z_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>

        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">z_indices</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">shuffle_neighbors</span> <span class="o">&lt;</span> <span class="n">T</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">2</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;            nearest-neighbor shuffle significance &quot;</span>
                      <span class="s2">&quot;test with n = </span><span class="si">%d</span><span class="s2"> and </span><span class="si">%d</span><span class="s2"> surrogates&quot;</span> <span class="o">%</span> <span class="p">(</span>
                      <span class="bp">self</span><span class="o">.</span><span class="n">shuffle_neighbors</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sig_samples</span><span class="p">))</span>

            <span class="c1"># Get nearest neighbors around each sample point in Z</span>
            <span class="n">z_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">fastCopyAndTranspose</span><span class="p">(</span><span class="n">array</span><span class="p">[</span><span class="n">z_indices</span><span class="p">,</span> <span class="p">:])</span>
            <span class="n">tree_xyz</span> <span class="o">=</span> <span class="n">spatial</span><span class="o">.</span><span class="n">cKDTree</span><span class="p">(</span><span class="n">z_array</span><span class="p">)</span>
            <span class="n">neighbors</span> <span class="o">=</span> <span class="n">tree_xyz</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">z_array</span><span class="p">,</span>
                                       <span class="n">k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shuffle_neighbors</span><span class="p">,</span>
                                       <span class="n">p</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
                                       <span class="n">eps</span><span class="o">=</span><span class="mf">0.</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>

            <span class="n">null_dist</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sig_samples</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">sam</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sig_samples</span><span class="p">):</span>

                <span class="c1"># Generate random order in which to go through indices loop in</span>
                <span class="c1"># next step</span>
                <span class="n">order</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">random_state</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">T</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>

                <span class="c1"># Shuffle neighbor indices for each sample index</span>
                <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">neighbors</span><span class="p">)):</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">random_state</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">neighbors</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
                <span class="c1"># neighbors = self.random_state.permuted(neighbors, axis=1)</span>
                
                <span class="c1"># Select a series of neighbor indices that contains as few as</span>
                <span class="c1"># possible duplicates</span>
                <span class="n">restricted_permutation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_restricted_permutation</span><span class="p">(</span>
                        <span class="n">T</span><span class="o">=</span><span class="n">T</span><span class="p">,</span>
                        <span class="n">shuffle_neighbors</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shuffle_neighbors</span><span class="p">,</span>
                        <span class="n">neighbors</span><span class="o">=</span><span class="n">neighbors</span><span class="p">,</span>
                        <span class="n">order</span><span class="o">=</span><span class="n">order</span><span class="p">)</span>

                <span class="n">array_shuffled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">array</span><span class="p">)</span>
                <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">x_indices</span><span class="p">:</span>
                    <span class="n">array_shuffled</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">array</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">restricted_permutation</span><span class="p">]</span>

                <span class="n">null_dist</span><span class="p">[</span><span class="n">sam</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_dependence_measure</span><span class="p">(</span><span class="n">array_shuffled</span><span class="p">,</span>
                                                             <span class="n">xyz</span><span class="p">)</span>

        <span class="k">else</span><span class="p">:</span>
            <span class="n">null_dist</span> <span class="o">=</span> \
                    <span class="bp">self</span><span class="o">.</span><span class="n">_get_shuffle_dist</span><span class="p">(</span><span class="n">array</span><span class="p">,</span> <span class="n">xyz</span><span class="p">,</span>
                                           <span class="bp">self</span><span class="o">.</span><span class="n">get_dependence_measure</span><span class="p">,</span>
                                           <span class="n">sig_samples</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sig_samples</span><span class="p">,</span>
                                           <span class="n">sig_blocklength</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sig_blocklength</span><span class="p">,</span>
                                           <span class="n">verbosity</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span><span class="p">)</span>

        <span class="n">pval</span> <span class="o">=</span> <span class="p">(</span><span class="n">null_dist</span> <span class="o">&gt;=</span> <span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>

        <span class="k">if</span> <span class="n">return_null_dist</span><span class="p">:</span>
            <span class="c1"># Sort</span>
            <span class="n">null_dist</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span>
            <span class="k">return</span> <span class="n">pval</span><span class="p">,</span> <span class="n">null_dist</span>
        <span class="k">return</span> <span class="n">pval</span></div>


<div class="viewcode-block" id="CMIknn.get_conditional_entropy"><a class="viewcode-back" href="../../../index.html#tigramite.independence_tests.cmiknn.CMIknn.get_conditional_entropy">[docs]</a>    <span class="k">def</span> <span class="nf">get_conditional_entropy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">array</span><span class="p">,</span> <span class="n">xyz</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns the nearest-neighbor conditional entropy estimate of H(X|Y).</span>

<span class="sd">        Parameters</span>
<span class="sd">        ---------- </span>
<span class="sd">        array : array-like</span>
<span class="sd">            data array with X, Y in rows and observations in columns</span>

<span class="sd">        xyz : array of ints</span>
<span class="sd">            XYZ identifier array of shape (dim,). Here only uses 0 for X and </span>
<span class="sd">            1 for Y.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        val : float</span>
<span class="sd">            Entropy estimate.</span>
<span class="sd">        &quot;&quot;&quot;</span>


        <span class="n">dim</span><span class="p">,</span> <span class="n">T</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">shape</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">knn</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">knn_here</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="o">*</span><span class="n">T</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">knn_here</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="p">))</span>


        <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>

        <span class="c1"># Add noise to destroy ties...</span>
        <span class="n">array</span> <span class="o">+=</span> <span class="p">(</span><span class="mf">1E-6</span> <span class="o">*</span> <span class="n">array</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
                  <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">random_state</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="n">array</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">array</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])))</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">==</span> <span class="s1">&#39;standardize&#39;</span><span class="p">:</span>
            <span class="c1"># Standardize</span>
            <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
            <span class="n">array</span> <span class="o">-=</span> <span class="n">array</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
            <span class="n">std</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">dim</span><span class="p">):</span>
                <span class="k">if</span> <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">!=</span> <span class="mf">0.</span><span class="p">:</span>
                    <span class="n">array</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/=</span> <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
            <span class="c1"># array /= array.std(axis=1).reshape(dim, 1)</span>
            <span class="c1"># FIXME: If the time series is constant, return nan rather than</span>
            <span class="c1"># raising Exception</span>
            <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">std</span> <span class="o">==</span> <span class="mf">0.</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;Possibly constant array!&quot;</span><span class="p">)</span>
            <span class="c1"># if np.isnan(array).sum() != 0:</span>
            <span class="c1">#     raise ValueError(&quot;nans after standardizing, &quot;</span>
            <span class="c1">#                      &quot;possibly constant array!&quot;)</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">==</span> <span class="s1">&#39;uniform&#39;</span><span class="p">:</span>
            <span class="n">array</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_trafo2uniform</span><span class="p">(</span><span class="n">array</span><span class="p">)</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">==</span> <span class="s1">&#39;ranks&#39;</span><span class="p">:</span>
            <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>

        <span class="c1"># Compute conditional entropy as H(X|Y) = H(X) - I(X;Y)</span>

        <span class="c1"># First compute H(X)</span>
        <span class="c1"># Use cKDTree to get distances eps to the k-th nearest neighbors for</span>
        <span class="c1"># every sample in joint space X with maximum norm</span>
        <span class="n">x_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">y_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>

        <span class="n">dim_x</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
        <span class="k">if</span> <span class="mi">1</span> <span class="ow">in</span> <span class="n">xyz</span><span class="p">:</span>
            <span class="n">dim_y</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">dim_x</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">dim_y</span> <span class="o">=</span> <span class="mi">0</span>


        <span class="n">x_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">fastCopyAndTranspose</span><span class="p">(</span><span class="n">array</span><span class="p">[</span><span class="n">x_indices</span><span class="p">,</span> <span class="p">:])</span>
        <span class="n">tree_xyz</span> <span class="o">=</span> <span class="n">spatial</span><span class="o">.</span><span class="n">cKDTree</span><span class="p">(</span><span class="n">x_array</span><span class="p">)</span>
        <span class="n">epsarray</span> <span class="o">=</span> <span class="n">tree_xyz</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">x_array</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="p">[</span><span class="n">knn_here</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">p</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
                                  <span class="n">eps</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>

        <span class="n">h_x</span> <span class="o">=</span> <span class="o">-</span> <span class="n">special</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">knn_here</span><span class="p">)</span> <span class="o">+</span> <span class="n">special</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">T</span><span class="p">)</span> <span class="o">+</span> <span class="n">dim_x</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">2.</span><span class="o">*</span><span class="n">epsarray</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>

        <span class="c1"># Then compute MI(X;Y)</span>
        <span class="k">if</span> <span class="n">dim_y</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">xyz_here</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">index</span> <span class="k">for</span> <span class="n">index</span> <span class="ow">in</span> <span class="n">xyz</span> <span class="k">if</span> <span class="n">index</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">index</span> <span class="o">==</span> <span class="mi">1</span><span class="p">])</span>
            <span class="n">array_xy</span> <span class="o">=</span> <span class="n">array</span><span class="p">[</span><span class="nb">list</span><span class="p">(</span><span class="n">x_indices</span><span class="p">)</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">y_indices</span><span class="p">),</span> <span class="p">:]</span>
            <span class="n">i_xy</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_dependence_measure</span><span class="p">(</span><span class="n">array_xy</span><span class="p">,</span> <span class="n">xyz_here</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">i_xy</span> <span class="o">=</span> <span class="mf">0.</span>

        <span class="n">h_x_y</span> <span class="o">=</span> <span class="n">h_x</span> <span class="o">-</span> <span class="n">i_xy</span>

        <span class="k">return</span> <span class="n">h_x_y</span></div>


    <span class="nd">@jit</span><span class="p">(</span><span class="n">forceobj</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">get_restricted_permutation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">shuffle_neighbors</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">order</span><span class="p">):</span>

        <span class="n">restricted_permutation</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
        <span class="n">used</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">sample_index</span> <span class="ow">in</span> <span class="n">order</span><span class="p">:</span>
            <span class="n">m</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="n">use</span> <span class="o">=</span> <span class="n">neighbors</span><span class="p">[</span><span class="n">sample_index</span><span class="p">,</span> <span class="n">m</span><span class="p">]</span>

            <span class="k">while</span> <span class="p">((</span><span class="n">use</span> <span class="ow">in</span> <span class="n">used</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">m</span> <span class="o">&lt;</span> <span class="n">shuffle_neighbors</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)):</span>
                <span class="n">m</span> <span class="o">+=</span> <span class="mi">1</span>
                <span class="n">use</span> <span class="o">=</span> <span class="n">neighbors</span><span class="p">[</span><span class="n">sample_index</span><span class="p">,</span> <span class="n">m</span><span class="p">]</span>

            <span class="n">restricted_permutation</span><span class="p">[</span><span class="n">sample_index</span><span class="p">]</span> <span class="o">=</span> <span class="n">use</span>
            <span class="n">used</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">used</span><span class="p">,</span> <span class="n">use</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">restricted_permutation</span>

<div class="viewcode-block" id="CMIknn.get_model_selection_criterion"><a class="viewcode-back" href="../../../index.html#tigramite.independence_tests.cmiknn.CMIknn.get_model_selection_criterion">[docs]</a>    <span class="k">def</span> <span class="nf">get_model_selection_criterion</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">parents</span><span class="p">,</span> <span class="n">tau_max</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns a cross-validation-based score for nearest-neighbor estimates.</span>

<span class="sd">        Fits a nearest-neighbor model of the parents to variable j and returns</span>
<span class="sd">        the score. The lower, the better the fit. Here used to determine</span>
<span class="sd">        optimal hyperparameters in PCMCI(pc_alpha or fixed thres).</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        j : int</span>
<span class="sd">            Index of target variable in data array.</span>

<span class="sd">        parents : list</span>
<span class="sd">            List of form [(0, -1), (3, -2), ...] containing parents.</span>

<span class="sd">        tau_max : int, optional (default: 0)</span>
<span class="sd">            Maximum time lag. This may be used to make sure that estimates for</span>
<span class="sd">            different lags in X, Z, all have the same sample size.</span>

<span class="sd">        Returns:</span>
<span class="sd">        score : float</span>
<span class="sd">            Model score.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="kn">import</span> <span class="nn">sklearn</span>
        <span class="kn">from</span> <span class="nn">sklearn.neighbors</span> <span class="kn">import</span> <span class="n">KNeighborsRegressor</span>
        <span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">cross_val_score</span>

        <span class="n">Y</span> <span class="o">=</span> <span class="p">[(</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">)]</span>
        <span class="n">X</span> <span class="o">=</span> <span class="p">[(</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">)]</span>   <span class="c1"># dummy variable here</span>
        <span class="n">Z</span> <span class="o">=</span> <span class="n">parents</span>
        <span class="n">array</span><span class="p">,</span> <span class="n">xyz</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataframe</span><span class="o">.</span><span class="n">construct_array</span><span class="p">(</span><span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="o">=</span><span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="o">=</span><span class="n">Z</span><span class="p">,</span>
                                                    <span class="n">tau_max</span><span class="o">=</span><span class="n">tau_max</span><span class="p">,</span>
                                                    <span class="n">mask_type</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_type</span><span class="p">,</span>
                                                    <span class="n">return_cleaned_xyz</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                                                    <span class="n">do_checks</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                                    <span class="n">verbosity</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span><span class="p">)</span>
        <span class="n">dim</span><span class="p">,</span> <span class="n">T</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">shape</span>

        <span class="c1"># Standardize</span>
        <span class="n">array</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
        <span class="n">array</span> <span class="o">-=</span> <span class="n">array</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">std</span> <span class="o">=</span> <span class="n">array</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">dim</span><span class="p">):</span>
            <span class="k">if</span> <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">!=</span> <span class="mf">0.</span><span class="p">:</span>
                <span class="n">array</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/=</span> <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
        <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">std</span> <span class="o">==</span> <span class="mf">0.</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;Possibly constant array!&quot;</span><span class="p">)</span>
            <span class="c1"># raise ValueError(&quot;nans after standardizing, &quot;</span>
            <span class="c1">#                  &quot;possibly constant array!&quot;)</span>

        <span class="n">predictor_indices</span> <span class="o">=</span>  <span class="nb">list</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span><span class="o">==</span><span class="mi">2</span><span class="p">)[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">predictor_array</span> <span class="o">=</span> <span class="n">array</span><span class="p">[</span><span class="n">predictor_indices</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">T</span>
        <span class="c1"># Target is only first entry of Y, ie [y]</span>
        <span class="n">target_array</span> <span class="o">=</span> <span class="n">array</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">xyz</span><span class="o">==</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="p">:]</span>

        <span class="k">if</span> <span class="n">predictor_array</span><span class="o">.</span><span class="n">size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="c1"># Regressing on ones if empty parents</span>
            <span class="n">predictor_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">T</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">knn</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">knn_here</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="o">*</span><span class="n">T</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">knn_here</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">knn</span><span class="p">))</span>

        <span class="n">knn_model</span> <span class="o">=</span> <span class="n">KNeighborsRegressor</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">=</span><span class="n">knn_here</span><span class="p">)</span>
        
        <span class="n">scores</span> <span class="o">=</span> <span class="n">cross_val_score</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">knn_model</span><span class="p">,</span> 
            <span class="n">X</span><span class="o">=</span><span class="n">predictor_array</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">target_array</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model_selection_folds</span><span class="p">,</span> <span class="n">n_jobs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">)</span>
        
        <span class="c1"># print(scores)</span>
        <span class="k">return</span> <span class="o">-</span><span class="n">scores</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></div></div>

<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
    
    <span class="kn">import</span> <span class="nn">tigramite</span>
    <span class="kn">from</span> <span class="nn">tigramite.data_processing</span> <span class="kn">import</span> <span class="n">DataFrame</span>
    <span class="kn">import</span> <span class="nn">tigramite.data_processing</span> <span class="k">as</span> <span class="nn">pp</span>
    <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>

    <span class="n">random_state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">default_rng</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>
    <span class="n">cmi</span> <span class="o">=</span> <span class="n">CMIknn</span><span class="p">(</span><span class="n">mask_type</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                   <span class="n">significance</span><span class="o">=</span><span class="s1">&#39;fixed_thres&#39;</span><span class="p">,</span>
                   <span class="n">sig_samples</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
                   <span class="n">sig_blocklength</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                   <span class="n">transform</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">,</span>
                   <span class="n">knn</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
                   <span class="n">verbosity</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

    <span class="n">T</span> <span class="o">=</span> <span class="mi">1000</span>
    <span class="n">dimz</span> <span class="o">=</span> <span class="mi">1</span>

    <span class="c1"># # Continuous data</span>
    <span class="c1"># z = random_state.standard_normal((T, dimz))</span>
    <span class="c1"># x = (1.*z[:,0] + random_state.standard_normal(T)).reshape(T, 1)</span>
    <span class="c1"># y = (1.*z[:,0] + random_state.standard_normal(T)).reshape(T, 1)</span>

    <span class="c1"># print(&#39;X _|_ Y&#39;)</span>
    <span class="c1"># print(cmi.run_test_raw(x, y, z=None))</span>
    <span class="c1"># print(&#39;X _|_ Y | Z&#39;)</span>
    <span class="c1"># print(cmi.run_test_raw(x, y, z=z))</span>

    <span class="c1"># Continuous data</span>
    <span class="n">z</span> <span class="o">=</span> <span class="n">random_state</span><span class="o">.</span><span class="n">standard_normal</span><span class="p">((</span><span class="n">T</span><span class="p">,</span> <span class="n">dimz</span><span class="p">))</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">random_state</span><span class="o">.</span><span class="n">standard_normal</span><span class="p">(</span><span class="n">T</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.</span><span class="o">*</span><span class="n">z</span><span class="p">[:,</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mf">1.</span><span class="o">*</span><span class="n">x</span><span class="p">[:,</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">random_state</span><span class="o">.</span><span class="n">standard_normal</span><span class="p">(</span><span class="n">T</span><span class="p">))</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

    <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">))</span>
    <span class="n">data</span><span class="p">[:,</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.5</span>
    <span class="nb">print</span> <span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="n">dataframe</span> <span class="o">=</span> <span class="n">DataFrame</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span>
    <span class="n">cmi</span><span class="o">.</span><span class="n">set_dataframe</span><span class="p">(</span><span class="n">dataframe</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">cmi</span><span class="o">.</span><span class="n">run_test</span><span class="p">(</span><span class="n">X</span><span class="o">=</span><span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)],</span> <span class="n">Y</span><span class="o">=</span><span class="p">[(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)],</span> <span class="n">alpha_or_thres</span><span class="o">=</span><span class="mf">0.5</span>  <span class="p">))</span>
    <span class="c1"># print(cmi.get_model_selection_criterion(j=1, parents=[], tau_max=0))</span>
    <span class="c1"># print(cmi.get_model_selection_criterion(j=1, parents=[(0, 0)], tau_max=0))</span>
    <span class="c1"># print(cmi.get_model_selection_criterion(j=1, parents=[(0, 0), (2, 0)], tau_max=0))</span>
</pre></div>

          </div>
          
        </div>
      </div>
      <div class="sphinxsidebar" role="navigation" aria-label="main navigation">
        <div class="sphinxsidebarwrapper">
<h1 class="logo"><a href="../../../index.html">Tigramite</a></h1>








<h3>Navigation</h3>

<div class="relations">
<h3>Related Topics</h3>
<ul>
  <li><a href="../../../index.html">Documentation overview</a><ul>
  <li><a href="../../index.html">Module code</a><ul>
  </ul></li>
  </ul></li>
</ul>
</div>
<div id="searchbox" style="display: none" role="search">
  <h3 id="searchlabel">Quick search</h3>
    <div class="searchformwrapper">
    <form class="search" action="../../../search.html" method="get">
      <input type="text" name="q" aria-labelledby="searchlabel" autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/>
      <input type="submit" value="Go" />
    </form>
    </div>
</div>
<script>document.getElementById('searchbox').style.display = "block"</script>








        </div>
      </div>
      <div class="clearer"></div>
    </div>
    <div class="footer">
      &copy;2023, Jakob Runge.
      
      |
      Powered by <a href="http://sphinx-doc.org/">Sphinx 5.0.2</a>
      &amp; <a href="https://github.com/bitprophet/alabaster">Alabaster 0.7.12</a>
      
    </div>

    

    
  </body>
</html>